import copy

# region youmi
# The local package shadows the PyPI `openai` distribution (same import name),
# so temporarily remove the local path from sys.path, import, then restore it.
import sys
from pprint import pformat
from typing import Dict, Iterator, List, Optional

from rtp_llm.config.py_config_modules import StaticConfig

# Index into sys.path of the entry that shadows the real `openai` package;
# -1 means no such entry was found.
path_index_list = -1
# The removed sys.path entry, kept so it can be restored after the import.
path_remove = ""
# Find the first sys.path entry pointing at the local "rtp_llm/rtp_llm"
# directory (presumably the one containing the shadowing module).
for idx, path in enumerate(sys.path):
    if path.endswith("rtp_llm/rtp_llm"):
        path_index_list = idx
        break
if path_index_list > -1:
    path_remove = sys.path.pop(path_index_list)
# NOTE: this import must run while the shadowing entry is removed, so that
# the installed `openai` distribution is resolved instead of the local one.
import openai

# Restore sys.path to its original state (same position) for later imports.
if path_remove:
    sys.path.insert(path_index_list, path_remove)
# endregion

# The exception type moved between SDK major versions: the legacy 0.x SDK
# exposes it as `openai.error.OpenAIError`, while 1.x re-exports it from the
# top-level `openai` package.
if openai.__version__.startswith("0."):
    from openai.error import OpenAIError  # noqa
else:
    from openai import OpenAIError

from qwen_agent.llm.base import ModelServiceError, register_llm
from qwen_agent.llm.text_base import BaseTextChatModel
from qwen_agent.log import logger

from .schema import ASSISTANT, Message


@register_llm("oai")
class TextChatAtOAI(BaseTextChatModel):
    """Text chat model served by an OpenAI-compatible HTTP endpoint.

    Supports both the legacy 0.x `openai` SDK (module-global configuration)
    and the 1.x SDK (client object); the difference is hidden behind
    ``self._chat_complete_create``, which both chat methods call.
    """

    def __init__(self, cfg: Optional[Dict] = None):
        """Configure endpoint, credentials and the version-appropriate caller.

        Args:
            cfg: Optional config dict. Recognized keys: ``model``,
                ``api_base`` / ``base_url`` / ``model_server`` (endpoint,
                first one present wins), ``api_key`` (falls back to
                ``StaticConfig.model_config.openai_api_key``).
        """
        super().__init__(cfg)
        self.model = self.model or "gpt-3.5-turbo"
        cfg = cfg or {}

        # Accept any of the three historical config keys for the endpoint.
        api_base = cfg.get(
            "api_base",
            cfg.get(
                "base_url",
                cfg.get("model_server", ""),
            ),
        ).strip()

        api_key = cfg.get("api_key", "")
        if not api_key:
            api_key = StaticConfig.model_config.openai_api_key
        api_key = api_key.strip()

        if openai.__version__.startswith("0."):
            # Legacy SDK: endpoint and key are module-global state.
            if api_base:
                openai.api_base = api_base
            if api_key:
                openai.api_key = api_key
            self._chat_complete_create = openai.ChatCompletion.create
        else:
            api_kwargs = {}
            if api_base:
                api_kwargs["base_url"] = api_base
            if api_key:
                api_kwargs["api_key"] = api_key

            # Created lazily on first use so that configuration/auth errors
            # surface at call time (as before), while later calls reuse the
            # same client and its underlying HTTP connection pool instead of
            # constructing a fresh client per request.
            self._client = None

            def _chat_complete_create(*args, **kwargs):
                # OpenAI API v1 rejects unknown top-level arguments; tunnel
                # server-specific sampling params through `extra_body`.
                extra_params = ["top_k", "repetition_penalty"]
                if any((k in kwargs) for k in extra_params):
                    kwargs["extra_body"] = copy.deepcopy(kwargs.get("extra_body", {}))
                    for k in extra_params:
                        if k in kwargs:
                            kwargs["extra_body"][k] = kwargs.pop(k)
                # v1 renamed the 0.x `request_timeout` kwarg to `timeout`.
                if "request_timeout" in kwargs:
                    kwargs["timeout"] = kwargs.pop("request_timeout")

                if self._client is None:
                    self._client = openai.OpenAI(**api_kwargs)
                return self._client.chat.completions.create(*args, **kwargs)

            self._chat_complete_create = _chat_complete_create

    def _chat_stream(
        self,
        messages: List[Message],
        delta_stream: bool,
        generate_cfg: dict,
    ) -> Iterator[List[Message]]:
        """Stream a chat completion.

        Args:
            messages: Conversation history.
            delta_stream: If True, yield only the new text of each chunk;
                otherwise yield the accumulated response so far.
            generate_cfg: Extra kwargs forwarded to the completions API.

        Yields:
            Single-element lists with the assistant message (delta or full).

        Raises:
            ModelServiceError: Wrapping any `OpenAIError` from the service.
        """
        messages = [msg.model_dump() for msg in messages]
        logger.debug(f"*{pformat(messages, indent=2)}*")
        try:
            response = self._chat_complete_create(
                model=self.model, messages=messages, stream=True, **generate_cfg
            )
            full_response = ""
            for chunk in response:
                # Some OpenAI-compatible servers emit keep-alive or
                # usage-only chunks whose `choices` list is empty; skip them
                # instead of raising an (uncaught) IndexError.
                if not chunk.choices:
                    continue
                delta = chunk.choices[0].delta
                text = getattr(delta, "content", None)
                if not text:
                    continue
                if delta_stream:
                    yield [Message(ASSISTANT, text)]
                else:
                    full_response += text
                    yield [Message(ASSISTANT, full_response)]
        except OpenAIError as ex:
            raise ModelServiceError(exception=ex)

    def _chat_no_stream(
        self,
        messages: List[Message],
        generate_cfg: dict,
    ) -> List[Message]:
        """Request a complete (non-streaming) chat completion.

        Args:
            messages: Conversation history.
            generate_cfg: Extra kwargs forwarded to the completions API.

        Returns:
            A single-element list with the assistant's full reply.

        Raises:
            ModelServiceError: Wrapping any `OpenAIError` from the service.
        """
        messages = [msg.model_dump() for msg in messages]
        logger.debug(f"*{pformat(messages, indent=2)}*")
        try:
            response = self._chat_complete_create(
                model=self.model, messages=messages, stream=False, **generate_cfg
            )
            return [Message(ASSISTANT, response.choices[0].message.content)]
        except OpenAIError as ex:
            raise ModelServiceError(exception=ex)
